import ekf_testv6 as ekf
from ekf_testv6 import dates, zJ_from_W, zJtomm
import numpy as np
# import pykalmani
import matplotlib.pyplot as plt
import matplotlib        as mpl
import scipy.stats as stats
import matplotlib.transforms as mtransforms


# --- Colormap ---------------------------------------------------------------
# Four all-ones (opaque white) rows, then the squared "Greens" ramp without its
# last 16 entries, then one opaque black row on top.
nfrac = 16
upper = mpl.cm.Greens(np.arange(256))
mycmap0 = np.vstack((
    np.ones((4, 4)),
    np.power(upper[:-16, :], 2),
    [0, 0, 0, 1],
))
# Wrap the RGBA table as a matplotlib colormap with one entry per row.
mycmap = mpl.colors.ListedColormap(mycmap0, name='myColorMap', N=len(mycmap0))

# --- Run configuration --------------------------------------------------------
trail_averaging = False
vari_of_parameters = False
source_name = "SSP_Nazarenko"  # alternative: "SSP"

# Increasing contour levels, later applied to the log of the plotted densities.
_level_steps = np.flip(np.arange(-2.25, 3.001, 0.25))
custlevelsl = -_level_steps * np.log(4) - np.log(25) + np.log(16)

SSPnames = [126, 434, 245, 370, 585]
# discard 434

# --- Figures: two data panels plus a narrow colorbar column each --------------
gridspec = {"wspace": 0.4, "width_ratios": [3.75, 3.75, 0.45]}
fig, axs = plt.subplots(1, 3, figsize=(8.25, 6), gridspec_kw=gridspec)
fig.suptitle("Projected Surface Climate State")
fig2, axs2 = plt.subplots(1, 3, figsize=(8.25, 6), gridspec_kw=gridspec)
fig2.suptitle("Projected Ocean Heat Content State")
#axs[1].set_visible(False)
#axs2[1].set_visible(False)

# One list per plotted scenario; filled with the final state of every sample.
end_temps = [[], []]
nsamps = 6000  # 5min per 1000, 6000 looks great
volcsplot = True
if(volcsplot):
  import VolcanoFit1 as VolcanoFit
  for rcpa in [0,1]:
      # Pick this scenario's GMST / OHCA panels and its column offset in the
      # projection table (index into SSPnames as well).
      ax=axs[rcpa]
      ax2=axs2[rcpa]
      rcpcreate=[0,3]
      rcp=rcpcreate[rcpa]
      #rcp=rcpa*3 #+1# to make it go to 4 and 8

      # Historical GMST: filtered state, +/-2 sigma band, and the observations.
      # Raw strings keep the LaTeX backslashes intact (no invalid-escape
      # warnings) and the mathtext braces are balanced.
      ax.plot(dates,ekf.xh1s,'-',
              label=r'past EBM-KF GMST state estimate $\hat{T}_{t}$',
              color=ekf.colorekf)
      ax.fill_between(dates, ekf.xh1s-2*ekf.stdP, ekf.xh1s+2*ekf.stdP,
                      label=r"95% CI ($\pm 2\sqrt{\hat{p}^T_t}$) of GMST state $\hat{T}_{t}$",
                      color=ekf.colorstate, alpha=0.5)
      ax.plot(dates,ekf.temps,'o',label=r'HadCRUT5 measurements $Y_{t}$',
              markersize=2,color=ekf.colorgrey)

      # Historical OHCA: same three layers, scaled by zJ_from_W for plotting.
      ohca_state=ekf.xh1d*zJ_from_W
      ohca_band=2*ekf.stdPd*zJ_from_W
      ax2.plot(dates,ohca_state,'-',
               label=r'past EBM-KF OHCA state estimate $\hat{H}_{t}$',
               color=ekf.colorekf)
      ax2.fill_between(dates, ohca_state-ohca_band, ohca_state+ohca_band,
                       label=r"95% CI ($\pm 2\sqrt{\hat{p}^H_t}$) of OHCA state $\hat{H}_{t}$",
                       color=ekf.colorstate, alpha=0.5)
      ax2.plot(dates,ekf.ocean_heat_measured,'o',
               label=r'Zanna (2019) measurements $\Psi_{t}$',
               markersize=2,color=ekf.colorgrey)

      # Projection input table for the selected source. The file is opened in
      # a context manager so the handle is closed again on every pass of the
      # scenario loop instead of being leaked.
      with open("KF6projection"+source_name+".csv", "rb") as forcing_file:
          data3 = np.genfromtxt(forcing_file,dtype=float, delimiter=',')
      startf=2023 #not 2024!
      endf=2100
      inum=4  # sub-annual columns per year in the density tables

      # Value grids on which the densities are evaluated: 0.001-wide steps for
      # GMST and 2-wide steps for OHCA.
      heights0=np.linspace(286.5,291.5,5001)
      heights1=np.linspace(-1000,8000,4501)

      ncols=(endf-startf)*inum
      nyears=endf-startf+1
      # Accumulated densities using the state spread only (pdftable*) and the
      # spread with the extra ekf.Rtvara / ekf.Roctvara term added (tppdftable*).
      pdftable0=np.zeros((heights0.size,ncols))
      pdftable1=np.zeros((heights1.size,ncols))
      tppdftable0=np.zeros((heights0.size,ncols))
      tppdftable1=np.zeros((heights1.size,ncols))
      # Per-sample yearly trajectories.
      sampGMST=np.zeros((nsamps,nyears))
      sampOHCA=np.zeros((nsamps,nyears))
      # Running sums of the two variance flavours (row 0: GMST, row 1: OHCA).
      mPs=np.zeros((2,nyears))
      mSs=np.zeros((2,nyears))
      
      # Monte Carlo loop: each pass draws one future trajectory from
      # ekf.ekf_future (given VolcanoFit) and adds its Gaussian densities to the
      # accumulation tables.
      years=range(startf,endf+1)
      # (weight on year t, weight on year t+1) for every sub-annual column;
      # computed once instead of on every inner iteration.
      interp_weights=[(1-interp/inum,interp/inum) for interp in range(inum)]
      for i in range(0,nsamps):
          #this uses VolcanoFit
          if(vari_of_parameters):
              # Optionally perturb two model parameters for this sample.
              ekf.gad = np.random.normal(0.67,0.15)
              std_fdbkA=(1.1+0.08)/2/1.65
              ekf.fdbkA = np.random.normal(0.42,std_fdbkA)
              ekf.precompute_coeffs(False)
          (fxhat,fP0)=ekf.ekf_future(startf,ekf.Plastretain,ekf.xlastretain,endf+1,np.log10(data3[:,1+rcp]),data3[:,6+rcp],VolcanoFit,trail_averaging)
          stdfP=np.sqrt(np.abs(fP0))
          mPs=mPs+np.abs([fP0[:,0,0],fP0[:,1,1]])
          # Spread with ekf.Rtvara / ekf.Roctvara added to the state variance
          # (presumably the measurement-noise variances -- confirm in ekf).
          stdfS00=np.sqrt(np.abs(fP0[:,0,0]+ekf.Rtvara))
          stdfS11=np.sqrt(np.abs(fP0[:,1,1]+ekf.Roctvara))
          mSs=mSs+np.square([stdfS00,stdfS11])
          #print(2*stdfP[-1,0,0] ,2*stdfP[-1,1,1])
          # Keep the final-year state so the end distribution can be inspected
          # interactively afterwards.
          end_temps[rcpa].append(fxhat[-1,:])
          sampGMST[i,:]=fxhat[:,0]
          sampOHCA[i,:]=fxhat[:,1]
          if(i==0):
              # First sample carries the legend entries; also draw the single
              # noVolcs=True reference run.
              ax.plot(years,fxhat[:,0],'-',color='k',lw=0.2,label="Samples from volc distn of EBM-KF GMST") #colorekf
              ax2.plot(years,fxhat[:,1]*zJ_from_W,'-',color='k',lw=0.2,label="Samples from volc distn of EBM-KF OHCA")
              (fxhatnv,fP0nv)=ekf.ekf_future(startf,ekf.Plastretain,ekf.xlastretain,endf+1,np.log10(data3[:,1+rcp]),data3[:,6+rcp], noVolcs=True)
              ax.plot(years,fxhatnv[:,0],'-',color='b',lw=0.6,label="Future EBM-KF GMST state, const. volc") #colorekf
              ax2.plot(years,fxhatnv[:,1]*zJ_from_W,'-',color='b',lw=0.6,label="Future EBM-KF OHCA state, const. volc")
          elif(i<10):
              # Draw a handful of further sample paths without legend entries.
              ax.plot(years,fxhat[:,0],'-',color='k',lw=0.2)
              ax2.plot(years,fxhat[:,1]*zJ_from_W,'-',color='k',lw=0.2)
              #print(stdfP[0],stdfP[-2])
          elif(i%100==0):
              # Progress indicator (fraction of samples done).
              print(i/nsamps)
          for t in range(endf-startf):
              for interp,(w0,w1) in enumerate(interp_weights):
                  col=t*inum+interp
                  # Linearly interpolated means between year t and t+1; the same
                  # means are shared by both density flavours.
                  iloc0=fxhat[t,0]*w0+fxhat[t+1,0]*w1
                  iloc1=(fxhat[t,1]*w0+fxhat[t+1,1]*w1)*zJ_from_W
                  # State-only spread.
                  iscale0=stdfP[t,0,0]*w0+stdfP[t+1,0,0]*w1
                  iscale1=(stdfP[t,1,1]*w0+stdfP[t+1,1,1]*w1)*zJ_from_W
                  pdftable0[:,col]=pdftable0[:,col]+stats.norm.pdf(heights0,loc=iloc0,scale=iscale0)
                  pdftable1[:,col]=pdftable1[:,col]+stats.norm.pdf(heights1,loc=iloc1,scale=iscale1)
                  # Spread including the additional variance term.
                  tpscale0=stdfS00[t]*w0+stdfS00[t+1]*w1
                  tpscale1=(stdfS11[t]*w0+stdfS11[t+1]*w1)*zJ_from_W
                  tppdftable0[:,col]=tppdftable0[:,col]+stats.norm.pdf(heights0,loc=iloc0,scale=tpscale0)
                  tppdftable1[:,col]=tppdftable1[:,col]+stats.norm.pdf(heights1,loc=iloc1,scale=tpscale1)

      #v_newmax=max(pdftable[:,int((endf-startf)*inum/2)])/nsamps/100
      mPs=mPs/nsamps;
      mSs=mSs/nsamps;
      if(False):
          nsamps=6000 %6000
          fdata = np.load("SSP"+"{:.0f}".format(SSPnames[rcp])+".npz")
          cdftable0=fdata['cgmst']
          pdftable0=cdftable0.copy()
          pdftable0[1:,:] -= cdftable0[:-1,:]
          pdftable0=pdftable0*nsamps
          cdftable1=fdata['cohca']
          pdftable1=cdftable1.copy()
          pdftable1[1:,:] -= cdftable1[:-1,:]
          pdftable1=pdftable1*nsamps
##          sampGMST=fdata['sgmst']
##          sampOHCA=fdata['sohca']
##          mPs=fdata['mps']
##          mSs=fdata['mss']
      else:
          cdftable0=np.cumsum(pdftable0/nsamps,axis=0)
          cdftable1=np.cumsum(pdftable1/nsamps,axis=0)
          tpcdftable0=np.cumsum(tppdftable0/nsamps,axis=0)
          tpcdftable1=np.cumsum(tppdftable1/nsamps,axis=0)
          if(trail_averaging):
              outfile=("SSP"+"{:.0f}ta".format(SSPnames[rcp]))
          elif(vari_of_parameters):
              outfile=("SSP"+"{:.0f}varpara".format(SSPnames[rcp]))
          elif(source_name != "SSP"):
              outfile=("SSP"+"{:.0f}".format(SSPnames[rcp])+source_name)
          else:
              outfile=("SSP"+"{:.0f}".format(SSPnames[rcp]))
          np.savez_compressed(outfile,gmstheights=heights0,ohcaheights=heights1, cgmst=cdftable0,cohca=cdftable1,
                              tpgmst=tpcdftable0,tpohca=tpcdftable1,sgmst=sampGMST, soca=sampOHCA, mps=mPs,  mss=mSs)
          
      # Sub-annual x positions, one per density-table column (computed once).
      xgrid=np.arange(startf,endf,1/inum)

      # GMST: filled contours of the log density, rescaled by 2/5 to match the
      # "per 0.4K" colorbar label.
      CS=ax.contourf(xgrid,heights0,np.log(pdftable0/nsamps*2/5+1e-20),levels=custlevelsl,cmap=mycmap,extend='both')
      # cdftable0 is a plain cumulative sum on a 0.001-spaced grid, so it tops
      # out near 1000; 25 and 975 are therefore the 2.5% and 97.5% levels.
      l25cdftable0=np.absolute(cdftable0-25)
      l975cdftable0=np.absolute(cdftable0-975)
      ax.plot(xgrid,heights0[np.argmin(l25cdftable0,axis=0)],'-',color=ekf.pcolor)
      ax.plot(xgrid,heights0[np.argmin(l975cdftable0,axis=0)],'-',color=ekf.pcolor,label="2.5-97.5% CI of volc EBM-KF GMST")

      # OHCA: same construction, rescaled to match the "per 250ZJ" label.
      CS2=ax2.contourf(xgrid,heights1,np.log(pdftable1/nsamps*250+1e-20),levels=custlevelsl,cmap=mycmap, extend='both')
      # The OHCA grid has spacing 2, so its cumulative sum tops out near 0.5.
      l25cdftable1=np.absolute(cdftable1-0.5*0.025)
      l975cdftable1=np.absolute(cdftable1-0.5*0.975)
      ax2.plot(xgrid,heights1[np.argmin(l25cdftable1,axis=0)],'-',color=ekf.pcolor)
      ax2.plot(xgrid,heights1[np.argmin(l975cdftable1,axis=0)],'-',color=ekf.pcolor,label="2.5-97.5% CI of volc EBM-KF OHCA")

      print("Maximum Densities")
      print(np.max(pdftable0/nsamps/10))
      print(np.max(pdftable1/nsamps/10))
      #CS.set_clim(custlevelsl[0],custlevelsl[-1]) 
      #CS2.set_clim(custlevelsl[0],custlevelsl[-1])


    
    #  plt.fill_between(range(2022,2100), fxhat-stdfP, fxhat+stdfP,label="68% CI of EBM-KF state ($\sqrt{ P_{n}}$)", color=colorstate, alpha=0.5)
      # --- GMST panel: Kelvin on the left, matching Celsius twin on the right.
      ax.set_xlabel('Year')
      ax.tick_params( direction = 'in',bottom=True, top=True, left=True, right=True )
      ax.set_xlim([1990,2100])
      ax.set_ylim(286.9,291.5)
      ax.set_yticks(np.arange(287,291.5,0.4))
      axb = ax.twinx()
      mn, mx = ax.get_ylim()
      axb.set_yticks(np.arange(14,18.8,0.4))
      axb.set_ylim(mn- 273.15, mx- 273.15)

      # --- OHCA panel: ZJ on the left, twin axis rescaled by zJtomm/10 (cm).
      ax2.set_xlabel('Year')
      ax2.tick_params( direction = 'in',bottom=True, top=True, left=True, right=True )
      ax2.set_xlim([1990,2100])
      ax2.set_yticks(np.arange(0,3250,250))
      ax2.set_ylim([250,3000])
      ax4b = ax2.twinx()
      mn, mx = ax2.get_ylim()
      ax4b.set_yticks(np.arange(0,40,2))
      ax4b.set_ylim(mn*zJtomm/10, mx*zJtomm/10)

      if(rcpa==0):
          # Left-hand panel: primary-axis labels and the legends.
          ax.set_ylabel('Temperature (K)')
          ax2.set_ylabel('Heat (ZJ)')
          ax.legend(loc="best",prop={'size': 7})
          ax2.legend(loc="best",prop={'size': 7})
      else:
          # Right-hand panel: only the twin axes are labelled.
          axb.set_ylabel('Temperature (°C)')
          ax4b.set_ylabel('Thermosteric Sea Level Rise (cm)')

      # Both figures get the same panel title (the OHCA one used to say "SSSP").
      axlabels=['a)','b)','c)','d)']
      label=axlabels[rcpa]
      panel_title=label+"        SSP"+"{:.0f}".format(SSPnames[rcp])+" Future Projection        " #"{:.1f}".format(
      ax.set_title(panel_title)
      ax2.set_title(panel_title)

      # Offset for placing the panel letter to the left of and above the axes;
      # only used by the disabled ax.text calls below.
      trans = mtransforms.ScaledTranslation(-(40-4*rcpa)/72, 15/72, fig.dpi_scale_trans) #-20/72, 7/72
      #ax.text(0.0, 1.0, label, transform=ax.transAxes + trans, fontsize='large', va='bottom')
      #ax2.text(0.0, 1.0, label, transform=ax.transAxes + trans, fontsize='large', va='bottom')
  # Colorbars live in the third (narrow) axes of each figure. Ticks sit on
  # every other contour level and are labelled with exp(level), i.e. in linear
  # rather than log density units.
  cbar_ticks=custlevelsl[0::2]
  cbar_ticklabels=np.round(np.exp(cbar_ticks),8)
  cbar = fig.colorbar(CS,cax=axs[2],ticks=cbar_ticks)
  # The OHCA colorbar belongs to fig2 (its axes come from axs2), so it is
  # created through fig2 rather than fig.
  cbar2 = fig2.colorbar(CS2,cax=axs2[2],ticks=cbar_ticks)
  cbar.ax.set_ylabel('Probability Density per 0.4K')
  cbar2.ax.set_ylabel('Probability Density per 250ZJ')
  cbar.ax.set_yticklabels(cbar_ticklabels)
  cbar2.ax.set_yticklabels(cbar_ticklabels)

# Display both figures, then write each one to a PDF whose file name carries
# the forcing-source tag.
plt.show()

gmst_pdf = "futGMST{}.pdf".format(source_name)
ohca_pdf = "futOHCA{}.pdf".format(source_name)
fig.savefig(gmst_pdf, format="pdf")
#fig.savefig("futGMSTsc.png", dpi=400,format="png")
fig2.savefig(ohca_pdf, format="pdf")
#fig2.savefig("futOHCAsc.png", dpi=400,format="png")
